import torch

 # 这是一个将索引转换为独热编码的函数。
def idx2onehot(idx, n):

    assert torch.max(idx).item() < n  # 确保索引值不超过n

    if idx.dim() == 1:  # 如果输入是一维的，则在第2维上增加一个维度
        idx = idx.unsqueeze(1)  # 增加一个维度
    onehot = torch.zeros(idx.size(0), n).to(idx.device)  # 构造一个全0的矩阵，大小为(batch_size, n)
    onehot.scatter_(1, idx, 1) # 将对应位置的值设为1
    
    return onehot  # 返回独热编码的结果
